{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "58a16aa6-be23-41b6-a182-d4f8e64d62b5",
   "metadata": {},
   "source": [
    "# Example 02 - Sig53 Classifier\n",
    "This notebook walks through a simple example of how to use the clean Sig53 dataset, load a pre-trained supported model, and evaluate the trained network's performance. Note that the experiment and the results herein are not to be interpreted with any significant value but rather serve simply as a practical example of how the `torchsig` dataset and tools can be used and integrated within a typical [PyTorch](https://pytorch.org/) and/or [PyTorch Lightning](https://www.pytorchlightning.ai/) workflow.\n",
    "\n",
    "----"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c205f51b-0be8-4158-935c-31969333713f",
   "metadata": {},
   "source": [
    "## Import Libraries\n",
    "First, import all the necessary public libraries as well as a few classes from the `torchsig` toolkit."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64d06b6b-6b91-4e20-a7c4-e9953c36a8f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchsig.transforms.target_transforms import DescToClassIndex\n",
    "from torchsig.utils.writer import DatasetCreator\n",
    "from torchsig.transforms.transforms import (\n",
    "    RandomPhaseShift,\n",
    "    Normalize,\n",
    "    ComplexTo2D,\n",
    "    Compose,\n",
    ")\n",
    "from pytorch_lightning.callbacks import ModelCheckpoint\n",
    "from pytorch_lightning import LightningModule, Trainer\n",
    "from sklearn.metrics import classification_report\n",
    "from torchsig.utils.cm_plotter import plot_confusion_matrix\n",
    "from torchsig.datasets.sig53 import Sig53\n",
    "from torchsig.datasets.modulations import ModulationsDataset\n",
    "from torchsig.datasets.datamodules import Sig53DataModule\n",
    "from torch.nn import CrossEntropyLoss\n",
    "from torchsig.datasets import conf\n",
    "from torch import optim\n",
    "import pytorch_lightning as pl\n",
    "from tqdm import tqdm\n",
    "import torch.nn.functional as F\n",
    "from torchmetrics import Accuracy\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "from torchsig.models import EfficientNet1d"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1dd97813-33b2-4e23-b29f-218d35c487fa",
   "metadata": {},
   "source": [
    "----\n",
    "### Instantiate Sig53 Dataset\n",
    "Here, we instantiate the Sig53 clean training dataset and the Sig53 clean validation dataset. We demonstrate how to compose multiple TorchSig transforms together, using a data impairment with a random phase shift that uniformly samples a phase offset between -1 pi and +1 pi. The next transform normalizes the complex tensor, and the final transform converts the complex data to a real-valued tensor with the real and imaginary parts as two channels. We additionally provide a target transform that maps the `SignalMetadata` objects, that are part of `SignalData` objects, to a desired format for the model we will train. In this case, we use the `DescToClassIndex` target transform to map class names to their indices within an ordered class list. Finally, we sample from our datasets and print details in order to confirm functionality.\n",
    "\n",
    "For more details on the Sig53 dataset instantiations, please see `00_example_sig53_dataset.ipynb`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf7a2414-0470-4538-939d-85e65f55f6f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Specify Sig53 Options\n",
    "root = \"./datasets/wideband_sig53\"\n",
    "impaired = False\n",
    "class_list = list(Sig53._idx_to_name_dict.values())\n",
    "batch_size = 64\n",
    "num_workers = 4\n",
    "\n",
    "transform = Compose(\n",
    "    [\n",
    "        RandomPhaseShift(phase_offset=(-1, 1)),\n",
    "        Normalize(norm=np.inf),\n",
    "        ComplexTo2D(),\n",
    "    ]\n",
    ")\n",
    "target_transform = DescToClassIndex(class_list=class_list)\n",
    "\n",
    "datamodule = Sig53DataModule(\n",
    "    root=root,\n",
    "    impaired=False,\n",
    "    transform=transform,\n",
    "    target_transform=target_transform,\n",
    "    batch_size=batch_size,\n",
    "    num_workers=num_workers\n",
    ")\n",
    "datamodule.prepare_data()\n",
    "datamodule.setup(\"fit\")\n",
    "\n",
    "# Retrieve a sample and print out information to verify\n",
    "idx = np.random.randint(len(datamodule.train))\n",
    "data, label = datamodule.train[idx]\n",
    "print(\"Dataset length: {}\".format(len(datamodule.train)))\n",
    "print(\"Data shape: {}\".format(data.shape))\n",
    "print(\"Label Index: {}\".format(label))\n",
    "print(\"Label Class: {}\".format(Sig53.convert_idx_to_name(label)))"
   ]
  },
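  {
   "cell_type": "markdown",
   "id": "transform-sketch-md",
   "metadata": {},
   "source": [
    "As a rough illustration of what the three composed transforms do to a single example, the cell below repeats the same steps in plain NumPy on a synthetic complex vector. This is a minimal sketch, not TorchSig's implementation: the signal `iq`, its length, and the random seed are made up for illustration, and it assumes that `Normalize(norm=np.inf)` divides by the largest sample magnitude."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "transform-sketch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "\n",
    "# Hypothetical stand-in for one complex baseband example (not read from Sig53)\n",
    "iq = rng.standard_normal(4096) + 1j * rng.standard_normal(4096)\n",
    "\n",
    "# 1) Random phase shift: rotate every sample by one phase drawn from [-pi, pi)\n",
    "phase = rng.uniform(-np.pi, np.pi)\n",
    "shifted = iq * np.exp(1j * phase)\n",
    "\n",
    "# 2) Infinity-norm normalization (assumed): scale so the largest magnitude is 1\n",
    "normalized = shifted / np.max(np.abs(shifted))\n",
    "\n",
    "# 3) Complex to two real channels: row 0 is the real part, row 1 the imaginary part\n",
    "two_channel = np.stack([normalized.real, normalized.imag])\n",
    "\n",
    "print(\"Output shape: {}\".format(two_channel.shape))\n",
    "print(\"Magnitudes unchanged by rotation: {}\".format(np.allclose(np.abs(shifted), np.abs(iq))))\n",
    "print(\"Peak magnitude is 1 after normalization: {}\".format(np.isclose(np.abs(normalized).max(), 1.0)))"
   ]
  },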
  {
   "cell_type": "markdown",
   "id": "72f37459-1136-4214-890c-a44279f7367f",
   "metadata": {},
   "source": [
    "----\n",
    "## Instantiate Supported TorchSig Model\n",
    "Below, we create a 1d EfficientNet-B0 model, and then conform it to a PyTorch LightningModule for training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8c34578-2740-4bd2-aeb4-c48b58821fa8",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = EfficientNet1d(2,53)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eade4610-b4ca-4900-99e7-e85c6b0f9892",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ExampleClassifier(pl.LightningModule):\n",
    "    def __init__(self, model):\n",
    "        super().__init__()\n",
    "        self.mdl = model\n",
    "        self.loss_fn = CrossEntropyLoss()\n",
    "        self.acc = Accuracy(task=\"multiclass\",num_classes=53,top_k=5)\n",
    "    def forward(self,batch):\n",
    "        return self.mdl(batch[0].float()).argmax(-1),batch[1]\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        return torch.optim.Adam(self.parameters())\n",
    "    \n",
    "    def training_step(self, batch):\n",
    "        x, y = (batch[0].float(),batch[1])\n",
    "        y_hat = self.mdl(x)\n",
    "        return self.loss_fn(y_hat, y)\n",
    "\n",
    "    def validation_step(self, batch):\n",
    "        x, y = (batch[0].float(),batch[1])\n",
    "        y_hat = self.mdl(x)\n",
    "        loss = self.loss_fn(y_hat, y)\n",
    "        self.acc(y_hat,y)\n",
    "        results = {\"val/loss\": loss, \"val/acc\": self.acc}\n",
    "        self.log_dict(results, prog_bar=True)\n",
    "        return results\n",
    "example_model = ExampleClassifier(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "33551e2c-cc3b-4384-93fb-28bb55680726",
   "metadata": {},
   "source": [
    "----\n",
    "## Train the Model\n",
    "To train the model, we first create a `ModelCheckpoint` to monitor the validation loss over time and save the best model as we go. The network is then instantiated and passed into a `Trainer` to kick off training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1503a4af-cf87-4a4b-ba39-ac46afe8e628",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = pl.Trainer(\n",
    "    max_epochs=25,\n",
    ")\n",
    "trainer.fit(example_model, datamodule)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cfda7603",
   "metadata": {},
   "outputs": [],
   "source": [
    "preds = trainer.predict(example_model,datamodule.val_dataloader())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c1cac6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "y_hat = torch.cat([p[0] for p in preds])\n",
    "y = torch.cat([p[1] for p in preds])"
   ]
  },
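  {
   "cell_type": "markdown",
   "id": "eval-report-md",
   "metadata": {},
   "source": [
    "The predictions collected above can also be summarized numerically. The cell below is a sketch that assumes `y`, `y_hat`, and `class_list` from the earlier cells are still in scope and that label indices follow the order of `class_list`; it prints the overall top-1 accuracy and scikit-learn's per-class precision, recall, and F1 report."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eval-report-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import classification_report\n",
    "\n",
    "# y and y_hat come from the previous cell; convert them to NumPy for scikit-learn\n",
    "y_true = y.cpu().numpy()\n",
    "y_pred = y_hat.cpu().numpy()\n",
    "\n",
    "# Overall top-1 accuracy on the validation split\n",
    "print(\"Top-1 accuracy: {:.4f}\".format((y_true == y_pred).mean()))\n",
    "\n",
    "# Per-class precision, recall, and F1; zero_division=0 reports 0 instead of\n",
    "# warning when a class has no predicted or no true samples\n",
    "print(\n",
    "    classification_report(\n",
    "        y_true,\n",
    "        y_pred,\n",
    "        labels=list(range(len(class_list))),\n",
    "        target_names=class_list,\n",
    "        zero_division=0,\n",
    "    )\n",
    ")"
   ]
  },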
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf330a55",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "accuracy_per_class = []\n",
    "for cls in range(53):\n",
    "    # Select predictions and labels for the current class\n",
    "    cls_predictions = y_hat[y == cls]\n",
    "    cls_labels = y[y == cls]\n",
    "    \n",
    "    # Calculate accuracy\n",
    "    accuracy = torch.mean((cls_predictions == cls_labels).float())\n",
    "    accuracy_per_class.append(accuracy)\n",
    "\n",
    "# Plot the accuracy per class\n",
    "plt.figure(figsize=(20, 5))\n",
    "plt.bar(class_list, accuracy_per_class)\n",
    "plt.xlabel('Class')\n",
    "plt.ylabel('Accuracy')\n",
    "plt.title('Accuracy per Class')\n",
    "plt.xticks(class_list,rotation=70)\n",
    "plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
